#!/usr/bin/env python

"""
Expects rows to be units (e.g. genes/KOs/etc.), and columns to be samples.

This script normalizes a table for sampling depth by either coverage per million (CPM) or based on the median-ratio 
method as performed in DESeq2. But unlike DESeq2, we don't care here if there are floats in there. 

I initially wrote this for normalizing metagenomic coverage data, like gene-level coverage, or summed KO coverages. 
These are normalized for gene-length already because they are "coverages", but they are not yet normalized 
for sampling depth – which is where this script comes in. 

I also found myself wanting this because I wanted to do differential abundance testing of coverages
of KO terms. DESeq2 doesn't require normalizing for gene-length because it is the same unit being analyzed
across all samples – the same gene, so the same size. However, after grouping genes into their KO annotations,
(which we may need to compare across samples that don't all share the same underlying assembly or genes), 
they no longer all represent the same units across all samples. It is because of this I decided to stick with 
gene-level coverages (which are normalized for gene-length), and then sum those values based on KO annotations.

The CPM (coverage per million) normalization is just like a percent, except scaled to 1 million instead of 100.
So each row's entry (e.g. gene/KO/etc.) is the proportion out of 1 million for that column (sample), 
and each column will sum to 1 million.

The median-ratio normalization method (MR) was initially described in this paper 
(http://dx.doi.org/10.1186/gb-2010-11-10-r106; eq. 5), and this site is super-informative in general 
about the DESeq2 process overall, and helped me understand the normalization process better to implement it: 
https://hbctraining.github.io/DGE_workshop/lessons/02_DGE_count_normalization.html. Columns will not sum to 
the same amount when the median-ratio method is applied.
"""

import os
import sys
import argparse
import pandas as pd
import numpy as np
from scipy.stats.mstats import gmean

# Command-line interface: one required input table, plus an optional normalization method and
# output filename. The help strings keep their backslash continuations so the text handed to
# argparse is unchanged.
parser = argparse.ArgumentParser(
    description="This script normalizes a coverage table for sampling depth based on either \
                                             coverage per million (CPM) or the median-ratio method (MR) as performed \
                                             in DESeq2. See note at top of script for more info. It expects a \
                                             tab-delimited table with samples as columns and units (e.g. genes/KOs/etc.) \
                                             as rows. For version info, run `bit-version`."
)

# required options get their own group so they are listed separately in the help output
required = parser.add_argument_group("required arguments")
required.add_argument(
    "-i",
    "--input-table",
    required=True,
    help="Input tab-delimited table",
)

parser.add_argument(
    "-n",
    "--normalization-method",
    choices=["CPM", "MR"],
    default="CPM",
    help='Desired normalization method of either \
                    "CPM" as in coverage per million, or "MR" as in median-ratio as performed in DESeq2. \
                    See note at top of program for more info. (default: "CPM")',
)

parser.add_argument(
    "-o",
    "--output-table",
    default="Normalized.tsv",
    help='Output filename (default: "Normalized.tsv")',
)

# called with no arguments at all: print the full help to stderr and exit cleanly
# rather than letting argparse complain about the missing required option
if len(sys.argv) == 1:
    parser.print_help(sys.stderr)
    sys.exit(0)

args = parser.parse_args()

################################################################################

class SizeFactorError(ValueError):
    """Raised when median-ratio size factors cannot be computed for a table."""


def median_ratio_size_factors(tab):
    """
    Return one median-ratio (DESeq2-style) size factor per column of `tab`.

    The geometric mean of each row across all columns is used as that row's reference
    value. A column's size factor is the median, over the usable rows, of the column's
    value divided by the row's reference. A row is usable only if its geometric mean is
    greater than zero, so any row holding a zero in at least one column is left out.

    Parameters:
        tab: pandas DataFrame with units as rows and samples as columns.

    Returns:
        pandas Series of size factors, in the same order as the columns of `tab`.

    Raises:
        SizeFactorError: if no row has a geometric mean greater than zero. The medians
            would all be NaN in that case, and dividing by them would turn the whole
            table into NaN.
    """

    # a zero in a row means log(0) inside gmean; the warning is silenced because such
    # rows end up with a geometric mean of 0 and are filtered out just below
    with np.errstate(divide = 'ignore'):
        geomeans = gmean(tab, axis = 1)

    usable_rows = geomeans > 0

    if not usable_rows.any():
        raise SizeFactorError("Median-ratio (MR) normalization needs at least one row with a non-zero "
                              "value in every sample that is not entirely zero, but the input table "
                              "has no such row. Consider the CPM method instead.")

    # ratios of each value to its row's geometric mean, for usable rows only
    ratios_tab = tab.loc[usable_rows].div(geomeans[usable_rows], axis = 0)

    return ratios_tab.median()


def normalize_table(tab, method = "CPM"):
    """
    Normalize a units-by-samples table for sampling depth.

    Columns that sum to zero are set aside before normalizing (they would otherwise
    cause divisions by zero) and are returned as columns of 0.0 in their original
    position. The input table is not modified.

    Parameters:
        tab: pandas DataFrame with units as rows and samples as columns.
        method: "CPM" scales each column to sum to 1 million; "MR" divides each column
            by its median-ratio size factor.

    Returns:
        pandas DataFrame with the same row index and column order as `tab`.

    Raises:
        SizeFactorError: for "MR" when no size factors can be computed
            (see median_ratio_size_factors).
        ValueError: if `method` is neither "CPM" nor "MR".
    """

    if method not in ("CPM", "MR"):
        raise ValueError('Normalization method must be "CPM" or "MR", got: ' + repr(method))

    # setting aside columns that are all zeroes prior to normalization (put back below)
    column_sums = tab.sum()
    zero_column_names = column_sums[column_sums == 0].index.tolist()
    working_tab = tab.drop(columns = zero_column_names)

    if working_tab.shape[1] == 0:

        # every column was all zeroes, so there is nothing to scale
        norm_tab = working_tab

    elif method == "CPM":

        norm_tab = working_tab / working_tab.sum() * 1000000

    else:

        size_factors = median_ratio_size_factors(working_tab)

        # positional division: one size factor per remaining column, in column order
        norm_tab = working_tab / size_factors.to_numpy()

    # adding back the all-zero columns as 0.0 and restoring the input column order
    return norm_tab.reindex(columns = tab.columns, fill_value = 0.0)


# Only run the read -> normalize -> write steps when executed as a script.
if __name__ == "__main__":

    tab = pd.read_csv(args.input_table, sep = "\t", index_col = 0, low_memory = False)

    try:
        norm_tab = normalize_table(tab, args.normalization_method)
    except SizeFactorError as err:
        # exit with a readable message (and non-zero status) rather than writing a table of NaNs
        sys.exit("Error: " + str(err))

    # writing out normalized table
    norm_tab.to_csv(args.output_table, sep = "\t")
